package com.jstarcraft.ai.jsat.lossfunctions;

import static java.lang.Math.abs;
import static java.lang.Math.exp;
import static java.lang.Math.log;

import com.jstarcraft.ai.jsat.classifiers.CategoricalResults;

/**
 * The LogisticLoss loss function for classification <i>L(x, y) =
 * log(1+exp(-y*x))</i>. <br>
 * This function is twice differentiable.
 *
 * @author Edward Raff
 */
public class LogisticLoss implements LossC {
    /*
     * NOTE: 30 used as a threshold b/c at the small values exp(-30) stradles the
     * edge of numerical double precision
     */

    private static final long serialVersionUID = -3929171604513497068L;

    /**
     * Computes the logistic loss
     *
     * @param pred the predicted value
     * @param y    the true value
     * @return the logistic loss
     */
    public static double loss(double pred, double y) {
        final double x = -y * pred;
        if (x >= 30)// as x -> inf, L(x) -> x. At 30 exp(x) is O(10^13), getting unstable. L(x)-x at
                    // this value is O(10^-14), also avoids exp and log ops
            return x;
        else if (x <= -30)
            return 0;
        return log(1 + exp(x));
    }

    /**
     * Computes the first derivative of the logistic loss
     *
     * @param pred the predicted value
     * @param y    the true value
     * @return the first derivative of the logistic loss
     */
    public static double deriv(double pred, double y) {
        final double x = y * pred;
        if (x >= 30)
            return 0;
        else if (x <= -30)
            return y;

        return -y / (1 + exp(y * pred));
    }

    /**
     * Computes the second derivative of the logistic loss
     *
     * @param pred the predicted value
     * @param y    the true value
     * @return the second derivative of the logistic loss
     */
    public static double deriv2(double pred, double y) {
        final double x = y * pred;
        if (x >= 30)
            return 0;
        else if (x <= -30)
            return 0;

        final double p = 1 / (1 + exp(y * pred));

        return p * (1 - p);
    }

    public static CategoricalResults classify(double score) {
        CategoricalResults cr = new CategoricalResults(2);
        final double p;
        if (score > 30)
            p = 1.0;
        else if (score < -30)
            p = 0.0;
        else
            p = 1 / (1 + Math.exp(-score));
        cr.setProb(0, 1 - p);
        cr.setProb(1, p);
        return cr;
    }

    @Override
    public double getLoss(double pred, double y) {
        return loss(pred, y);
    }

    @Override
    public double getDeriv(double pred, double y) {
        return deriv(pred, y);
    }

    @Override
    public double getDeriv2(double pred, double y) {
        return deriv2(pred, y);
    }

    @Override
    public double getDeriv2Max() {
        return 1.0 / 4.0;
    }

    @Override
    public LogisticLoss clone() {
        return this;
    }

    @Override
    public CategoricalResults getClassification(double score) {
        return classify(score);
    }

    @Override
    public double getConjugate(double b, double pred, double y) {
        if (b < 0 || b > 1)
            return Double.POSITIVE_INFINITY;
        // else
        if (0.5 - abs(0.5 - b) < 1e-13) // you are so close tot he edges of making log(b) return NaN, so lets round you
                                        // out
            return 0;// This is
        return b * log(b) + (1 - b) * log(1 - b); // log(max(1e, b)) done to avoid NaN with -0 inputs
    }

    @Override
    public double lipschitz() {
        return 4;
    }
}
